{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b389d4fd-41ec-4a53-b4de-ac8687c5fc73",
   "metadata": {},
   "source": [
    "# L3: Large Multimodal Models (LMMs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f6dd19f",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px\"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "108c2ba6",
   "metadata": {},
   "source": [
    "* In this classroom, the libraries have already been installed for you.\n",
    "* If you would like to run this code on your own machine, you need to install the following:\n",
    "```\n",
    "    !pip install google-generativeai python-dotenv pillow imageio numpy matplotlib\n",
    "```\n",
    "\n",
    "Note: To use the Gemini Vision model, set your GOOGLE_API_KEY, either in your `.env` file or, in a notebook, with the `%env` magic:\n",
    "```\n",
    "   %env GOOGLE_API_KEY=************\n",
    "```\n",
    "Check the [documentation](https://ai.google.dev/gemini-api/docs/api-key) for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24f47148-c30f-4a0d-9d49-79e13e259c7d",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5177f964-237d-41bf-9cc1-4dd358e5c26a",
   "metadata": {},
   "source": [
    "## Setup\n",
    "### Load environment variables and API keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b21ac31-961c-400e-8a70-e0b0ce367aa9",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "\n",
    "_ = load_dotenv(find_dotenv())  # read local .env file\n",
    "GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47fab612-7eee-419f-9f2f-d49715640c3f",
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "# Configure the genai library with the API key and endpoint\n",
    "import google.generativeai as genai\n",
    "from google.api_core.client_options import ClientOptions\n",
    "\n",
    "genai.configure(\n",
    "        api_key=GOOGLE_API_KEY,\n",
    "        transport=\"rest\",\n",
    "        client_options=ClientOptions(\n",
    "            api_endpoint=os.getenv(\"GOOGLE_API_BASE\"),\n",
    "        ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb72cf5f",
   "metadata": {},
   "source": [
    "> Note: To run this notebook locally, you need your own [GOOGLE_API_KEY](https://ai.google.dev/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfa5699c-118a-472e-9af1-cc588e5cec7d",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d71fc5b0-89cd-4fac-ba92-bcf8d398897b",
   "metadata": {
    "height": 149
   },
   "outputs": [],
   "source": [
    "import textwrap\n",
    "import PIL.Image\n",
    "from IPython.display import Markdown, Image\n",
    "\n",
    "def to_markdown(text):\n",
    "    \"\"\"Render text as a Markdown blockquote, turning '•' bullets into list items.\"\"\"\n",
    "    text = text.replace('•', '  *')\n",
    "    return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64de3c92",
   "metadata": {},
   "source": [
    "* Function to call the LMM (Large Multimodal Model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5127688-e33f-4e3c-b390-5bb21dc22f39",
   "metadata": {
    "height": 183
   },
   "outputs": [],
   "source": [
    "def call_LMM(image_path: str, prompt: str) -> Markdown:\n",
    "    \"\"\"Send an image and a text prompt to the LMM; return the reply as Markdown.\"\"\"\n",
    "    # Load the image\n",
    "    img = PIL.Image.open(image_path)\n",
    "\n",
    "    # Call the generative model, streaming the response\n",
    "    model = genai.GenerativeModel('gemini-pro-vision')\n",
    "    response = model.generate_content([prompt, img], stream=True)\n",
    "    response.resolve()  # wait for the streamed response to finish\n",
    "\n",
    "    return to_markdown(response.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa6c3330-e9fd-49ce-acde-1d01ada7cbd4",
   "metadata": {},
   "source": [
    "## Analyze images with an LMM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ffdd495-d552-4ea8-b537-d1e66cfe4647",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# Pass in an image and see if the LMM can answer questions about it\n",
    "Image(url=\"SP-500-Index-Historical-Chart.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b034f67-b4a4-416b-a660-9b4d9fa6a741",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "# Use the LMM function\n",
    "call_LMM(\"SP-500-Index-Historical-Chart.jpg\", \n",
    "    \"Explain what you see in this image.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e13510df-c2ca-4bae-88e3-a425dd3c4f90",
   "metadata": {},
   "source": [
    "## Analyze a harder image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f87ecc37",
   "metadata": {},
   "source": [
    "* Try something harder: Here's a figure we explained previously!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "092595da-a186-48ef-9fff-1f9e03b602a4",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "Image(url=\"clip.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6884c0e0-0b7d-42de-b697-420d461e6336",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "call_LMM(\"clip.png\", \n",
    "    \"Explain what this figure is and where it is used.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74fbc266-3c54-44d6-ac98-0e8c90829347",
   "metadata": {},
   "source": [
    "## Decode the hidden message"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a155d96",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6ff; padding:15px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px\"> 💻 &nbsp; <b>Access Utils File and Helper Functions:</b> To access the files for this notebook, 1) click on the <em>\"File\"</em> option on the top menu of the notebook and then 2) click on <em>\"Open\"</em>. For more help, please see the <em>\"Appendix - Tips and Help\"</em> Lesson.</p>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbc87edc-bcf8-490c-aafc-61c2a3d306b7",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "Image(url=\"blankimage3.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b79706a-f336-41ca-96e0-211282b6d8a0",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "# Ask to find the hidden message\n",
    "call_LMM(\"blankimage3.png\", \n",
    "    \"Read what you see on this image.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b8fa4c9-8576-4f45-9e8e-60723f75fdd4",
   "metadata": {},
   "source": [
    "## How the model sees the picture!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13ff5b03",
   "metadata": {},
   "source": [
    "> You have to be careful! The model does not \"see\" in the same way that we see!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4c234f6-6ece-486a-ab9b-70e3082a1412",
   "metadata": {
    "height": 183
   },
   "outputs": [],
   "source": [
    "import imageio.v2 as imageio\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "image = imageio.imread(\"blankimage3.png\")\n",
    "\n",
    "# Convert the image to a NumPy array\n",
    "image_array = np.array(image)\n",
    "\n",
    "# Binarize the red channel (threshold 120) so that low-contrast text stands out\n",
    "plt.imshow(np.where(image_array[:, :, 0] > 120, 0, 1), cmap='gray');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4df0fe88",
   "metadata": {},
   "source": [
    "### Try it yourself!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68a02c44",
   "metadata": {},
   "source": [
    "**EXTRA!**  You can use the function below to hide your own message in an image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eaaf218",
   "metadata": {
    "height": 149
   },
   "outputs": [],
   "source": [
    "# Draw text in a color almost identical to the background, so it is hard to see\n",
    "def create_image_with_text(text, font_size=20, font_family='sans-serif', text_color='#73D955', background_color='#7ED957'):\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    fig.patch.set_facecolor(background_color)\n",
    "    ax.text(0.5, 0.5, text, fontsize=font_size, ha='center', va='center', color=text_color, fontfamily=font_family)\n",
    "    ax.axis('off')\n",
    "    plt.tight_layout()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58373242",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "# Modify the text here to create a new hidden message image!\n",
    "fig = create_image_with_text(\"Hello, world!\")\n",
    "\n",
    "# Save the image with the hidden message, then display it\n",
    "fig.savefig(\"extra_output_image.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b153c37",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "# Call the LMM function with the image just generated\n",
    "call_LMM(\"extra_output_image.png\", \n",
    "    \"Read what you see on this image.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a259b951",
   "metadata": {},
   "source": [
    "* It worked! Now plot the image to decode the message."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1c76e87",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "image = imageio.imread(\"extra_output_image.png\")\n",
    "\n",
    "# Convert the image to a NumPy array\n",
    "image_array = np.array(image)\n",
    "\n",
    "# Binarize the red channel: the background (red = 126) maps to 0, the text (red = 115) maps to 1\n",
    "plt.imshow(np.where(image_array[:, :, 0] > 120, 0, 1), cmap='gray');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
